
import glob
import json
import os
import math
import argparse
import copy
from collections import deque

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import medfilt
matplotlib.rcParams.update({'font.size': 8})



# Fixed line color for each method name (the hex codes of matplotlib's
# default "tab10" cycle), looked up by the method directory name when plotting
# so a method keeps the same color across figures.
color_defaults = dict(
    cask='#1f77b4',      # muted blue
    vanilla='#ff7f0e',   # safety orange
    coins='#2ca02c',     # cooked asparagus green
    csd='#d62728',       # brick red
    rnd='#9467bd',       # muted purple
    TEST='#8c564b',      # chestnut brown
    diayn='#e377c2',     # raspberry yogurt pink
    # '#7f7f7f' (middle gray) is intentionally left unused.
    elden='#bcbd22',     # curry yellow-green
    cask_cmi='#17becf',  # blue-teal
    caskl='#1f77b4',     # muted blue (same as cask)
)

# performance_factor

def compute_error_bars(results, step_size, maxrange):
    """Bin several runs onto a common step grid and summarise each bin.

    Bin edges are ``0, step_size, 2*step_size, ...`` up to ``maxrange``.
    Every run in ``results`` is a sequence of ``(step, value)`` pairs, read
    with a forward-only cursor. Presumably the runs are sorted by step; that
    is not checked. At each edge, the not-yet-consumed samples of a run are
    taken up to (but not including) the first sample whose step exceeds the
    edge.

    Args:
        results: list of runs, each a sequence of ``(step, value)`` pairs.
        step_size: spacing between bin edges.
        maxrange: last bin edge to consider.

    Returns:
        ``(steps, means, stds)``, three lists with one entry per non-empty bin:

        * ``steps``: mean step of the samples that fell in the bin.
        * ``means``: mean value. If it falls more than 75% below the running
          best mean (which starts at 0.0), the running best is reported
          instead.
        * ``stds``: standard error of the mean plus a floor of 0.002.

        With no runs at all, every bin edge is returned with mean 0 and the
        0.002 floor as its error.
    """
    min_std = 0.002
    n_bins = int(maxrange // step_size + 1)
    bin_edges = [k * step_size for k in range(n_bins)]

    if len(results) == 0:
        return list(bin_edges), [0] * n_bins, [min_std] * n_bins

    steps, means, stds = [], [], []
    cursors = [0] * len(results)  # next unread sample in each run
    best = 0.0
    for edge in bin_edges:
        bucket_steps = []
        bucket_vals = []
        for run_idx, run in enumerate(results):
            pos = cursors[run_idx]
            while pos < len(run):
                s, v = run[pos]
                if s > edge:
                    break
                bucket_steps.append(s)
                bucket_vals.append(v)
                pos += 1
            cursors[run_idx] = pos

        if not bucket_steps:
            continue

        steps.append(np.mean(bucket_steps))
        avg = np.mean(bucket_vals)
        best = max(avg, best)
        # Keep reporting the running best once the current mean has dropped
        # below a quarter of it.
        means.append(best if best - avg > best * 0.75 else avg)
        stds.append(np.std(bucket_vals) / np.sqrt(float(len(bucket_vals))) + min_std)
    return steps, means, stds


def smooth(y, box_pts):
    """Smooth ``y`` with a moving average of width ``box_pts``.

    Element ``i`` of the result (for ``i <= len(y) - box_pts``) is the mean
    of ``y[i:i + box_pts]``. The last ``box_pts - 1`` elements have no full
    window, so they keep their original values. The output always has the
    same length as ``y``, so it can be plotted against the same x values.

    Args:
        y: 1-D sequence of numbers.
        box_pts: window width; 1 means no smoothing.

    Returns:
        A 1-D float ndarray with ``len(y)`` elements.

    Raises:
        ValueError: if ``box_pts`` is smaller than 1.
    """
    if box_pts < 1:
        raise ValueError(f"box_pts must be >= 1, got {box_pts}")
    y = np.asarray(y, dtype=float)
    if box_pts == 1 or len(y) < box_pts:
        # Nothing to smooth, or too short for a single full window. In the
        # short case np.convolve(mode='valid') swaps its operands, and
        # appending the tail would make the result longer than y.
        return y.copy()
    box = np.ones(box_pts) / box_pts
    y_smooth = np.convolve(y, box, mode='valid')
    return np.concatenate([y_smooth, y[-box_pts + 1:]])

def create_graph_simple(xs, values, names, errors = None):
    """Draw one line per series on the current matplotlib axes.

    Args:
        xs: x values shared by every series.
        values: sequence of y-value sequences, one per series.
        names: legend label for each series, paired with ``values`` by
            position.
        errors: optional sequence with one error magnitude (scalar or
            per-point array) per series. When given, a translucent
            ``value +/- error`` band is drawn around each line.

    Colors are taken from ``color_defaults`` in insertion order. They wrap
    around when there are more series than colors.
    """
    # color_defaults is keyed by method name, so pick colors by position from
    # its values rather than indexing the dict with an integer.
    palette = list(color_defaults.values())
    series = list(zip(values, names))
    band_sizes = [None] * len(series) if errors is None else list(errors)
    for ci, ((vs, ns), error) in enumerate(zip(series, band_sizes)):
        color = palette[ci % len(palette)]
        plt.plot(xs, vs, label=ns, color=color)
        if error is not None:
            # Convert to arrays so list inputs add element-wise instead of
            # being concatenated.
            vs_arr = np.asarray(vs, dtype=float)
            err_arr = np.asarray(error, dtype=float)
            plt.fill_between(xs, vs_arr + err_arr, vs_arr - err_arr, alpha=0.1, color=color)

def read_csv(filename, start_val, read_mode):
    """Read ``(iteration, value)`` pairs from a comma-separated log file.

    Which fields are used depends on ``read_mode`` and the row length:

    * ``read_mode == "vanilla"``: iteration is the first field, value the last.
    * rows with more than two fields: iteration and value are the last two.
    * otherwise: iteration is the first field, value the second.

    Rows that cannot be parsed are skipped. This covers header lines, blank
    lines, rows with too few fields, and non-numeric fields. If the parsed
    data does not start at iteration 0, a ``(0, start_val)`` point is
    prepended.

    Args:
        filename: path of the file to read.
        start_val: value used for the synthetic iteration-0 point.
        read_mode: parsing mode; only ``"vanilla"`` is special-cased.

    Returns:
        List of ``(int, float)`` tuples in file order. A file with no
        parseable rows yields ``[(0, start_val)]``.
    """
    res = []
    with open(filename, 'r') as file:
        for raw in file:
            fields = raw.split(",")
            try:
                if read_mode == "vanilla":
                    itr, val = int(fields[0]), float(fields[-1])
                elif len(fields) > 2:  # assume itr and value are the last two
                    itr, val = int(fields[-2]), float(fields[-1])
                else:  # first is itr, second is value
                    itr, val = int(fields[0]), float(fields[1])
            except (ValueError, IndexError):
                continue
            # Both fields are parsed before anything is stored, so a
            # half-parsed row cannot shift later iterations onto the wrong
            # values.
            res.append((itr, val))
    if not res or res[0][0] != 0:
        res = [(0, start_val)] + res
    print(len(res))
    return res

def compile_results(pth, start_val, read_mode):
    """Load every regular file in the folder ``pth`` as one run.

    Each file is parsed with ``read_csv(path, start_val, read_mode)``. The
    file name and the running count of loaded runs are printed after each
    file.

    Returns:
        A list with one ``[(itr, value), ...]`` list per file, in the order
        produced by ``list_files_in_folder``.
    """
    runs = []
    for fname in list_files_in_folder(pth):
        full_path = os.path.join(pth, fname)
        runs.append(read_csv(full_path, start_val, read_mode))
        print(fname, len(runs))
    return runs

import os

def list_files_in_folder(folder_path):
    """Return the names of the regular files directly inside ``folder_path``.

    Subdirectories are excluded. The names are sorted because ``os.listdir``
    returns entries in arbitrary order, and callers should load runs (and
    print them) in a reproducible order.
    """
    return sorted(
        name for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )


if __name__=='__main__':
    parser = argparse.ArgumentParser(description='RL')
    parser.add_argument('--path', default='break')
    parser.add_argument('--title', default="title")
    parser.add_argument('--target', default='plot.png')
    parser.add_argument('--start-val', type =int, default=0 )
    parser.add_argument('--smooth', type =int, default=3 )
    parser.add_argument('--step-size', type =int, default=100000)
    parser.add_argument('--xlim', type=int, default=1000000)
    parser.add_argument('--ylim', type=float, default=1)
    parser.add_argument('--ylabel', default="")
    parser.add_argument('--skip-key', nargs='+', default=[])
    parser.add_argument('--xlabel', default="")
    args = parser.parse_args()

    # Each sub-directory of --path is one method; each file inside it is one
    # run. Top-level entries that are not directories are ignored, because
    # compile_results cannot list them. Sorting gives a reproducible order.
    # Skipped methods are not loaded at all.
    methods = sorted(
        m for m in os.listdir(args.path)
        if os.path.isdir(os.path.join(args.path, m))
    )
    results = list()
    method_keys = list()
    for method in methods:
        if method in args.skip_key:
            continue
        result = compile_results(os.path.join(args.path, method), args.start_val, read_mode=method)
        results.append(result)
        method_keys.append(method)
        print(method)

    def plot(results, name):
        """Draw one method's smoothed mean curve with a +/- error band."""
        steps, meanvals, stdvs = compute_error_bars(results, args.step_size, args.xlim)
        steps = np.array(steps)
        returns = smooth(np.array(meanvals), args.smooth)
        error = np.array(stdvs)
        # Methods without an entry in color_defaults fall back to
        # matplotlib's color cycle. The band reuses the line's actual color.
        color = color_defaults.get(name)
        line_kwargs = {} if color is None else {'color': color}
        line, = plt.plot(steps, returns, label=name, linewidth=2.5, **line_kwargs)
        plt.fill_between(steps, returns + error, returns - error, alpha=0.1, color=line.get_color())

    for result, key in zip(results, method_keys):
        print("plotting", key)
        plot(result, key)

    ax = plt.gca()
    ax.spines[['right', 'top']].set_visible(False)

    plt.xlim(0, args.xlim)
    # Tick labels in millions at 0, xlim/2 and xlim. This uses true division:
    # floor division (xlim // 1e6) mislabelled e.g. xlim=500000 as "0.0M" and
    # xlim=2500000 as "2.0M".
    xlabels = ["0"] + ['{:}'.format(x) + 'M' for x in [args.xlim / 1e6 / 2, args.xlim / 1e6]]

    plt.tick_params(axis='both', which='major', labelsize=20)
    plt.xticks(ticks=[0, args.xlim // 2, args.xlim], labels=xlabels)
    plt.yticks(ticks=[0, args.ylim / 2, args.ylim])
    print(args.xlim)
    plt.ylim(0, args.ylim)
    if args.ylabel:
        plt.ylabel(args.ylabel, fontsize=15)
    if args.xlabel:
        plt.xlabel(args.xlabel, fontsize=15)
    plt.savefig(args.target)


# legend_list = []
# for i in range(len(load_fn)):
#     # , color=colors[parent_idx]
#     print(f"plotting {load_fn[i]}")
#     legend, = plt.plot(step_lists[i], mean_lists[i], label=method_nm[i], lw=2)
#     plt.fill_between(step_lists[i], mean_lists[i] - std_lists[i], mean_lists[i] + std_lists[i], alpha=.18)
#     legend_list.append(legend)
# if show_y_label:
#     # plt.ylabel("Success Rate")
#     plt.ylabel(y_label)
# plt.ylim([0, 1])
# plt.xlim(0, max_step_count)
# if plot_steps:
#     plt.xlabel("Number of Steps")
#     xlabels = ["0"] + ['{:}'.format(x) + 'M' for x in [max_step_count//2000, max_step_count // 1000]]
# else:
#     plt.xlabel("Number of Episodes")
#     xlabels = ["0"] + ['{:}'.format(x) + 'K' for x in [max_step_count//2000, max_step_count // 1000]]
# ax = plt.gca()
# ax.spines[['right', 'top']].set_visible(False)
# plt.xticks(ticks=[0, max_step_count//2, max_step_count], labels=xlabels)
# plt.yticks(ticks=[0, 0.5, 1.0])
# # legendFig = plt.figure("Legend plot")
# # legendFig.legend(legend_list,method_nm, loc=2, ncol = len(method_nm))
# # legendFig.savefig('legend.png')
# if use_legend:
#     plt.legend(loc="upper left")
# plt.tight_layout()
